## ============================================================================
## 08_si_s2_audience_iv.R
## ----------------------------------------------------------------------------
## Purpose:   Instrumental variables estimation of local average treatment
##            effects (LATEs) and nonparametric ATE bounds using the
##            methods of Kennedy (2020).
##            Corresponds to SI Section S2.2.2 (IV subsection).
##
## Requires:  06_si_s2_audience_effects.R must be run first
##            (this script reads the pooled ATEs .rds file listed under Inputs).
##
## Inputs:    combined_data.rds (via 00_setup.R),
##            ates_study1_pooled_unadj.rds (from 06)
## Outputs:   Tables/table_s29.tex,
##            Tables/table_s30.tex
## ============================================================================

source("00_setup.R")
library('npcausal')

## ---- Table S29: Compliers and non-compliers by sample ------

## Helper: compute compliance-type proportions and SEs for a given subset.
##
## Z_land_info is used as the binary instrument (Z) and y_recall_binary as the
## binary take-up indicator (D). Rows with missing Z or missing D are ignored.
##
## Args:
##   df: data frame containing the columns Z_land_info and y_recall_binary.
##
## Returns:
##   A data.frame with three rows (Complier, Never-taker, Always-taker, in that
##   order) and columns `type`, `prop` (estimated share) and `se` (binomial
##   standard error; for compliers, the SE of a difference in two independent
##   proportions).
##
## Raises:
##   An error if `df` is not a data frame or lacks a required column; a warning
##   if either instrument arm is empty (the estimates are then NaN).
compute_complier_props <- function(df) {
  stopifnot(
    is.data.frame(df),
    all(c("Z_land_info", "y_recall_binary") %in% names(df))
  )

  z <- df[["Z_land_info"]]
  d <- df[["y_recall_binary"]]

  ## which() drops rows whose Z is missing; missing D is removed afterwards
  d0 <- d[which(z == 0)]
  d1 <- d[which(z == 1)]
  d0 <- d0[!is.na(d0)]
  d1 <- d1[!is.na(d1)]

  n0 <- length(d0)
  n1 <- length(d1)

  if (n0 == 0 || n1 == 0) {
    warning(
      "compute_complier_props(): an instrument arm has no non-missing ",
      "observations; estimates will be NaN.",
      call. = FALSE
    )
  }

  ## Take-up rates by arm, defined via `== 1` so that all three shares rest on
  ## the same coding of D
  p0 <- mean(d0 == 1)  # P(D = 1 | Z = 0)
  p1 <- mean(d1 == 1)  # P(D = 1 | Z = 1)

  ## Always-takers: P(D = 1 | Z = 0)
  p_AT <- p0
  se_AT <- sqrt(p_AT * (1 - p_AT) / n0)

  ## Never-takers: P(D = 0 | Z = 1)
  p_NT <- mean(d1 == 0)
  se_NT <- sqrt(p_NT * (1 - p_NT) / n1)

  ## Compliers: P(D = 1 | Z = 1) - P(D = 1 | Z = 0)
  p_C <- p1 - p0
  se_C <- sqrt(p1 * (1 - p1) / n1 + p0 * (1 - p0) / n0)

  data.frame(
    type = c("Complier", "Never-taker", "Always-taker"),
    prop = c(p_C, p_NT, p_AT),
    se   = c(se_C, se_NT, se_AT)
  )
}

## Compliance-type shares for the pooled sample and for each country, turned
## into one formatted cell per (type, sample) via format_num()/add_parens()
## and laid out with one row per type and one column per sample.
complier_dat <-
  list(
    "Combined" =
      combined_dat %>% filter(!is.na(Z_land_info)),
    "Australia" =
      combined_dat %>% filter(sample == "AU", !is.na(Z_land_info)),
    "United States" =
      combined_dat %>% filter(sample == "US", !is.na(Z_land_info))
  ) %>%
  lapply(compute_complier_props) %>%
  ## The list names become the `group` column
  bind_rows(.id = "group") %>%
  mutate(
    type = factor(
      type,
      levels = c("Complier", "Never-taker", "Always-taker")
    ),
    group = factor(
      group,
      levels = c("Australia", "United States", "Combined")
    ),
    entry = paste0(format_num(prop), " ", add_parens(se))
  ) %>%
  select(type, group, entry) %>%
  pivot_wider(names_from = group, values_from = entry) %>%
  arrange(type) %>%
  select(type, Australia, `United States`, Combined)

## Export for latex table: table body only (no column header, no rules), with
## the italicised type labels used as row names
print(
  complier_dat %>%
    mutate(type = paste0("\\textit{", type, "}")) %>%
    column_to_rownames("type") %>%
    xtable(),
  type = "latex",
  file = paste0(tab_dir, "table_s29.tex"),
  only.contents = TRUE,
  include.rownames = TRUE,
  include.colnames = FALSE,
  hline.after = NULL,
  comment = FALSE,
  ## Keep the LaTeX markup in the row names unescaped
  sanitize.text.function = identity
)

## ---- In-text estimates for SI S2.2 IV section -----

## Cross-country differences in compliance-type shares.
##
## The shares and SEs are taken directly from compute_complier_props() on each
## country's subsample (the same calls that feed Table S29) instead of being
## typed in as two-decimal numbers. This keeps the in-text figures in sync with
## the data and avoids propagating rounding error into the SEs and p-values;
## the results may therefore differ in the last digit from figures computed
## from the rounded table entries.
props_au <- compute_complier_props(
  combined_dat %>% filter(sample == "AU", !is.na(Z_land_info))
)
props_us <- compute_complier_props(
  combined_dat %>% filter(sample == "US", !is.na(Z_land_info))
)

## Difference (a - b) in one type's share between two disjoint samples, with
## its SE and a two-sided normal-approximation p-value.
##   props_a, props_b: outputs of compute_complier_props()
##   stratum:          value of `type` to compare
diff_in_share <- function(props_a, props_b, stratum) {
  row_a <- props_a[props_a$type == stratum, ]
  row_b <- props_b[props_b$type == stratum, ]
  est <- row_a$prop - row_b$prop
  se <- sqrt(row_a$se^2 + row_b$se^2)
  ## lower.tail = FALSE keeps precision for large |z|, where 1 - pnorm(z)
  ## evaluates to exactly 0
  list(
    est  = est,
    se   = se,
    pval = 2 * pnorm(abs(est / se), lower.tail = FALSE)
  )
}

## Difference in Always-taker proportions: AU vs US
res_AT <- diff_in_share(props_au, props_us, "Always-taker")
diff_AT <- res_AT$est
se_diff_AT <- res_AT$se
pval_diff_AT <- res_AT$pval

## Difference in Never-taker proportions: US vs AU
res_NT <- diff_in_share(props_us, props_au, "Never-taker")
diff_NT <- res_NT$est
se_diff_NT <- res_NT$se
pval_diff_NT <- res_NT$pval

## Difference in Complier proportions: AU vs US
res_C <- diff_in_share(props_au, props_us, "Complier")
diff_C <- res_C$est
se_diff_C <- res_C$se
pval_diff_C <- res_C$pval

cat("Always-takers (AU - US):", format_num(diff_AT),
    "se =", format_num(se_diff_AT), "P =", format_num(pval_diff_AT, 3), "\n")
cat("Never-takers (US - AU):", format_num(diff_NT),
    "se =", format_num(se_diff_NT), "P =", format_num(pval_diff_NT, 3), "\n")
cat("Compliers (AU - US):", format_num(diff_C),
    "se =", format_num(se_diff_C), "P =", format_num(pval_diff_C, 2), "\n")

## First-stage: effect of Z on D (combined sample)
first_stage <- lm_robust(y_recall_binary ~ Z_land_info,
                         data = combined_dat %>% filter(!is.na(Z_land_info)))
cat("First stage:", format_num(first_stage$coefficients["Z_land_info"]),
    "se =", format_num(first_stage$std.error["Z_land_info"]),
    "F =", round(first_stage$fstatistic[1], 0),
    "P = ", format_num(broom::glance(first_stage)$p.value, 2), "\n")

## ---- Estimation using npcausal (slow) ------
## Covariates supplied to the nuisance models for the pooled analysis; the
## `sample` column is included alongside the respondent-level covariates.
covariates <- c(
  "x_edu_college", "x_employ_binary", "x_gender_f", "x_age_cat",
  "x_hhi", "x_native", "x_pid_comb2", "sample"
)

## Standardized outcomes analysed in the pooled loop, in estimation order.
outcomes <- c(
  "y_symbolic_idx_s", "y_substantive_idx_s", "y_cbr_idx_s",
  "y_guilt_idx_s", "y_infoseek_s", "y_acknowledge_s"
)

## Estimation data for the pooled analysis: (re)build the standardized
## outcomes, keep rows with an observed instrument, and retain only the
## covariates, instrument, take-up indicator and outcomes.
##
## glass_delta() is defined outside this file (presumably in 00_setup.R); it
## is called here with the instrument as Z and reference = 0.
##
## y_voice_vote_s is deliberately not rebuilt here: it is not one of the
## pooled `outcomes`, so it would be dropped by the select() below. The Voice
## analysis further down reads that column from combined_dat directly.
est_dat <-
  combined_dat %>%
  mutate(
    ## Note that we need to center and scale indices before standardizing or
    ## the function will give nonsense estimates for ATE bounds
    y_guilt_idx_s = glass_delta(
      as.numeric(scale(y_guilt_idx)), Z = Z_land_info, reference = 0
    ),
    y_cbr_idx_s = glass_delta(
      as.numeric(scale(y_cbr_idx)), Z = Z_land_info, reference = 0
    ),
    y_substantive_idx_s = glass_delta(
      as.numeric(scale(y_substantive_idx)), Z = Z_land_info, reference = 0
    ),
    y_symbolic_idx_s = glass_delta(
      as.numeric(scale(y_symbolic_idx)), Z = Z_land_info, reference = 0
    ),
    y_infoseek_s = glass_delta(y_infoseek_n, Z = Z_land_info, reference = 0),
    y_acknowledge_s = glass_delta(
      y_acknowledge_n, Z = Z_land_info, reference = 0
    )
  ) %>%
  filter(!is.na(Z_land_info)) %>%
  select(all_of(covariates),
         Z_land_info,
         y_recall_binary,
         all_of(outcomes))

## Rows with no missing value in any covariate, the instrument, the take-up
## indicator or any pooled outcome
index <- complete.cases(est_dat)

## Take-up (D) and instrument (Z) as numeric vectors on the complete cases
D <- as.numeric(est_dat$y_recall_binary[index])
Z <- as.numeric(est_dat$Z_land_info[index])

X <- est_dat[index, c(covariates)]

## Append indicator columns for each of these covariates. The formula built by
## reformulate(v, intercept = FALSE) is `~ v - 1`, i.e. no intercept column.
## The source columns themselves stay in X at this point.
for (dummy_var in c("x_age_cat", "x_hhi", "x_pid_comb2", "sample")) {
  dummies <- model.matrix(reformulate(dummy_var, intercept = FALSE), data = X)
  X <- cbind(X, dummies)
}

## Keep only the numeric columns (this drops any non-numeric source columns)
## and convert to a matrix for the npcausal estimators. drop = FALSE keeps a
## one-column result as a data frame so its column name survives as.matrix().
nums <- vapply(X, is.numeric, logical(1))
X <- as.matrix(X[, nums, drop = FALSE])

## Estimate, for every pooled outcome, the ivlate() and ivbds() results and
## store each `$res` table tagged with the outcome name.
##
## The seed is set once before the loop, and ivlate() is always called before
## ivbds() within an iteration, so the random sample splits are reproducible
## only if this call order is kept.
set.seed(123)
ivlate_out <- vector("list", length(outcomes))
ivbds_out <- vector("list", length(outcomes))
## NOTE(review): gamma_dat and N are not referenced again in this script; they
## are kept only in case something outside this file relies on them.
gamma_dat <- list()
N <- length(D)

## Covariate matrix with a leading column of ones; identical for every outcome
X_design <- cbind(1, X)

for (i in seq_along(outcomes)) {
  Y <- as.matrix(est_dat[index, outcomes[i]])[, 1]

  ## Instrumental-variable estimates for this outcome; `$res` holds one row
  ## per parameter (the LATE and Sharpness rows are used further below)
  ivlate_est <- ivlate(
    y = Y,
    a = D,
    z = Z,
    x = X_design,
    ## Set larger splits (See Remark 3.1 in https://doi.org/10.1111/ectj.12097)
    nsplits = 6,
    sl.lib = c("SL.glm")
  )

  ## Estimate bounds on treatment effects using instrumental variables
  ivbds_est <- ivbds(
    y = Y,
    a = D,
    z = Z,
    x = X_design,
    nsplits = 6,
    sl.lib = c("SL.glm")
  )

  ## The scalar outcome name is recycled to however many rows `$res` has
  ivlate_out[[i]] <- cbind(ivlate_est$res, outcome = outcomes[i])
  ivbds_out[[i]] <- cbind(ivbds_est$res, outcome = outcomes[i])
}

## Repeat for Voice outcome (AU only)
## Same covariates as the pooled analysis, minus `sample`.
covariates <-
  c(
    "x_edu_college",
    "x_employ_binary",
    "x_gender_f",
    "x_age_cat",
    "x_hhi",
    "x_native",
    "x_pid_comb2"
  )

## NOTE(review): y_voice_vote_s is read as it already exists in combined_dat.
## It is not re-standardized here. Confirm that the upstream definition is the
## intended one for this analysis.
est_dat <-
  combined_dat %>%
  filter(!is.na(Z_land_info)) %>%
  select(all_of(covariates), Z_land_info, y_recall_binary, "y_voice_vote_s")

## Rows with no missing covariate, instrument, take-up or Voice outcome value
index <- complete.cases(est_dat)

X <- est_dat[index, c(covariates)]

## Append indicator columns for each of these covariates. The formula built by
## reformulate(v, intercept = FALSE) is `~ v - 1`, i.e. no intercept column.
for (dummy_var in c("x_age_cat", "x_hhi", "x_pid_comb2")) {
  dummies <- model.matrix(reformulate(dummy_var, intercept = FALSE), data = X)
  X <- cbind(X, dummies)
}

## Keep numeric covariates only and convert to a matrix; drop = FALSE keeps a
## one-column result as a data frame so its column name survives as.matrix()
nums <- vapply(X, is.numeric, logical(1))
X <- as.matrix(X[, nums, drop = FALSE])

## Take-up (D), instrument (Z) and outcome (Y) on the complete cases
D <- as.numeric(est_dat$y_recall_binary[index])
Z <- as.numeric(est_dat$Z_land_info[index])
Y <- as.matrix(est_dat[index, "y_voice_vote_s"])[, 1]

## Voice outcome: same two estimators and settings as the pooled loop, run on
## the Voice-specific Y, D, Z and X. The seed is reset first, and ivlate() is
## called before ivbds().
set.seed(123)

## Covariate matrix with a leading column of ones, shared by both calls
voice_design <- cbind(1, X)

## Set larger splits (See Remark 3.1 in https://doi.org/10.1111/ectj.12097)
ivlate_voice <- ivlate(
  y = Y, a = D, z = Z, x = voice_design,
  nsplits = 6, sl.lib = c("SL.glm")
)

## Estimate bounds on treatment effects using instrumental variables
ivbds_voice <- ivbds(
  y = Y, a = D, z = Z, x = voice_design,
  nsplits = 6, sl.lib = c("SL.glm")
)

## Stack the pooled-outcome results with the Voice results. The Voice tables
## carry no `outcome` column yet, so they are labelled before binding.
ivlate_combined <- bind_rows(
  ivlate_out,
  ivlate_voice$res %>% mutate(outcome = "y_voice_vote_s")
)

ivbds_combined <- bind_rows(
  ivbds_out,
  ivbds_voice$res %>% mutate(outcome = "y_voice_vote_s")
)

## Mean estimated instrument sharpness across outcomes. The estimate is ~ 0,
## so compliance is not predicted by covariates. Hence, we cannot generalize
## the LATEs to a well-defined subgroup of respondents defined by covariates.
## Printed explicitly so the value is also shown when the script is source()d.
print(
  ivlate_combined %>%
    filter(parameter == "Sharpness") %>%
    pull(est) %>%
    mean()
)

## ---- Table S30: LATEs and ATE bounds -----
## LATE rows with Bonferroni-adjusted CIs and p-values, one row per outcome,
## ordered as in the table.
##
## The family size for the adjusted CIs is the number of LATE rows, n(). This
## is the same count p.adjust() uses for the Bonferroni p-values, so the two
## adjustments cannot drift apart if an outcome is added or removed.
est_late <-
  ivlate_combined %>%
  filter(parameter == "LATE") %>%
  mutate(
    ## Two-sided critical values at family-wise 5% and 10%
    z95_adj = qnorm(1 - (0.05 / n()) / 2),
    z90_adj = qnorm(1 - (0.10 / n()) / 2),
    ## Add CIs adjusting for multiple comparisons
    ci.ll95_adj = est - z95_adj * se,
    ci.ul95_adj = est + z95_adj * se,
    ci.ll90_adj = est - z90_adj * se,
    ci.ul90_adj = est + z90_adj * se,
    ## Variable names -> display labels; the level order sets the row order
    outcome = factor(
      outcome,
      levels = c(
        "y_acknowledge_s",
        "y_infoseek_s",
        "y_guilt_idx_s",
        "y_cbr_idx_s",
        "y_symbolic_idx_s",
        "y_substantive_idx_s",
        "y_voice_vote_s"
      ),
      labels = c(
        "Indigenous Land",
        "Information Seeking",
        "Collective Guilt",
        "Colorblind Racism",
        "Symbolic Reparations",
        "Substantive Reparations",
        "Voice Referendum"
      )
    ),
    pval_adj = p.adjust(pval, method = "bonferroni")
  ) %>%
  select(outcome, est, se, ci.ll95_adj, ci.ul95_adj, ci.ll90_adj, ci.ul90_adj,
         pval_adj) %>%
  arrange(outcome)

## Display label for each outcome variable; the order of this vector is the
## factor level order.
ate_outcome_labels <- c(
  y_acknowledge_s     = "Indigenous Land",
  y_infoseek_s        = "Information Seeking",
  y_guilt_idx_s       = "Collective Guilt",
  y_cbr_idx_s         = "Colorblind Racism",
  y_symbolic_idx_s    = "Symbolic Reparations",
  y_substantive_idx_s = "Substantive Reparations",
  y_voice_vote_s      = "Voice Referendum"
)

## ATE rows from the bounds results, with outcomes relabelled for display
est_ate <-
  ivbds_combined %>%
  filter(parameter == "ATE") %>%
  mutate(
    outcome = factor(
      outcome,
      levels = names(ate_outcome_labels),
      labels = unname(ate_outcome_labels)
    )
  )


## Formatted table cells per outcome: the LATE entry from table_entry(), the
## adjusted 95% and 90% CIs, the adjusted p-value, and the ATE bounds.
est_tab <-
  est_late %>%
  select(outcome, est, se,
         ci.ll95_adj, ci.ul95_adj,
         ci.ll90_adj, ci.ul90_adj,
         pval_adj) %>%
  mutate(
    late = table_entry(est = est, se = se),
    late_ci95 = sprintf(
      "[%.2f, %.2f]", round(ci.ll95_adj, 2), round(ci.ul95_adj, 2)
    ),
    late_ci90 = sprintf(
      "[%.2f, %.2f]", round(ci.ll90_adj, 2), round(ci.ul90_adj, 2)
    ),
    ## Values that round to 0.001 but are not below it fall through to the
    ## default branch, which prints them as "0.001"
    late_p = case_when(
      pval_adj < 0.001 ~ "< 0.001",
      pval_adj > 0.99 ~ "> 0.99",
      .default = as.character(round(pval_adj, 3))
    )
  ) %>%
  select(outcome, late, late_ci95, late_ci90, late_p) %>%
  left_join(
    est_ate %>%
      mutate(
        ate_bounds = sprintf("(%.2f, %.2f)", round(lb, 2), round(ub, 2))
      ) %>%
      select(outcome, ate_bounds),
    by = "outcome"
  ) %>%
  ## Format any column that is still numeric at this point to two decimals
  mutate(across(where(is.numeric), \(x) sprintf("%.2f", round(x, 2)))) %>%
  arrange(outcome)

## Combined with ITTs
## Row order of the final table follows the factor levels of est_tab$outcome
outcome_levs <- levels(est_tab$outcome)

## Pooled unadjusted estimates, read from the .rds file listed under Inputs
itt_dat <- read_rds("ates_study1_pooled_unadj.rds")

## One ITT cell per outcome (excluding "Indigenous Awareness"), joined to the
## LATE / ATE-bound cells by outcome label
latex_tab <-
  itt_dat %>%
  select(outcome = outcome_group, estimate, std.error, pval_adj) %>%
  filter(outcome != "Indigenous Awareness") %>%
  transmute(
    outcome,
    itt = table_entry(est = estimate, se = std.error)
  ) %>%
  left_join(est_tab, by = "outcome") %>%
  mutate(outcome = factor(outcome, levels = outcome_levs)) %>%
  arrange(outcome)

## Write the body of Table S30 (no header row, no row names, no rules), with
## outcome labels italicised.
print(
  latex_tab %>%
    mutate(outcome = paste0("\\textit{", outcome, "}")) %>%
    select(outcome, itt, late, late_ci95, late_ci90, late_p, ate_bounds) %>%
    xtable(),
  type = "latex",
  file = paste0(tab_dir, "table_s30.tex"),
  only.contents = TRUE,
  include.rownames = FALSE,
  include.colnames = FALSE,
  ## NULL removes all default rules (prevents \midrule)
  hline.after = NULL,
  comment = FALSE,
  ## Keep the LaTeX markup in the cells unescaped
  sanitize.text.function = identity
)

## ---- Session Info -----------------------------------------------------------
## Printed explicitly so it is also recorded when the script is source()d
print(sessionInfo())
